
# Copyright 2014 Nervana Systems Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

[-
    # Build-time configuration for this kernel variant, kept in Perl package
    # globals. Nothing in the visible part of this file reads $prefix,
    # $shareI, $shareF, $stepI, $stepF or $remapI directly; presumably the
    # macros from the included xconv_xprop_common.sass consume them
    # (confirm there).
    #
    # $int16 is declared but never assigned here. If nothing sets it before
    # this block runs, it is false and the half-float conversion is selected.
    # $shareI and $shareF carry the same widths as szShareI (128*8) and
    # szShareF (32*8) in the constant map below.
    our $int16;
    our $prefix  = 'h';
    our $shareI  = 128;
    our $shareF  = 32;
    our $stepI   = 32;
    our $stepF   = 16;
    our $remapI  = 1;
    # Opcode that widens one 16-bit input element to fp32: signed 16-bit
    # integer source when $int16 is true, half-float source otherwise.
    our $convert = $int16 ? 'I2F.F32.S16' : 'F2F.F32.F16';
    # Accessor used by the inline convert() calls on the instruction lines.
    sub convert {return $convert;}
-]

<INCLUDE file="xconv_xprop_common.sass"/>

<CONSTANT_MAPPING>
    // Shared memory layout. szShareF and szShareI are the sizes of one F tile
    // and one I tile (32*8 and 128*8 elements). Every addr_* constant is
    // placed after two copies of both tiles (32*8*2 + 128*8*2), so the
    // scratch words below follow the tile storage.
    // NOTE(review): addr_zero is defined twice with an identical value.
    addr_zero : 4x<32*8*2 + 128*8*2 + 0>
    szShareF  : (32*8)
    szShareI  : (128*8)

    // Scratch words that follow the tiles (offsets relative to addr_zero):
    //   +0       addr_zero   source of the fallback loads taken when P1 is false
    //   +4..+7   m, p, q, k  addr_mpqk equals addr_m, so the four values can
    //                        be fetched with a single 128-bit load
    //   +8       addr_szLut
    //   +10      addr_lut    base of the table that is read as 64-bit entries
    addr_zero  : 4x<32*8*2 + 128*8*2 + 0>
    addr_mpqk  : 4x<32*8*2 + 128*8*2 + 4>
    addr_m     : 4x<32*8*2 + 128*8*2 + 4>
    addr_p     : 4x<32*8*2 + 128*8*2 + 5>
    addr_q     : 4x<32*8*2 + 128*8*2 + 6>
    addr_k     : 4x<32*8*2 + 128*8*2 + 7>
    addr_szLut : 4x<32*8*2 + 128*8*2 + 8>
    addr_lut   : 4x<32*8*2 + 128*8*2 + 10>

[+ params() +]

</CONSTANT_MAPPING>

<REGISTER_MAPPING>

    // Registers 0-63: an 8x8 grid of accumulators named by column (cx0..cx7)
    // and row (y0..y7). The physical numbering is interleaved rather than
    // linear (presumably to avoid register bank conflicts - confirm).
     3, 2,11,10,19,18,27,26 : cx<0-7>y0
     7, 6,15,14,23,22,31,30 : cx<0-7>y1
     1, 0, 9, 8,17,16,25,24 : cx<0-7>y2
     5, 4,13,12,21,20,29,28 : cx<0-7>y3
    35,34,43,42,51,50,59,58 : cx<0-7>y4
    39,38,47,46,55,54,63,62 : cx<0-7>y5
    33,32,41,40,49,48,57,56 : cx<0-7>y6
    37,36,45,44,53,52,61,60 : cx<0-7>y7

    // The same 64 registers under flat names.
       0-63 : czero<00-63>

    // Setup-phase names. Several lists cover the same register range: they
    // are alternative names for the same physical registers and are valid at
    // different times. A ':' list binds names to registers in order; a '~'
    // list lets the assembler pick registers inside the range (maxas
    // convention). The mpqk names overlay idx_M, idx_P, idx_Q and idx_K so
    // that one 128-bit load fills all four.
    // NOTE(review): "64-69 : m, p, q" gives three names for a six-register
    // range - confirm this is intended.
      64-67 : mpqk<0-3>
      64-69 : m, p, q
      64-69 : idx_M, idx_P, idx_Q, idx_K, tidY, negOne
     70-113 ~ tid1, tidIX, tidFX, idx_MPQk, idx_PQk, idx_Qk, idx_k, magic_PQk, magic_Qk, neg_PQk, neg_Qk, neg_k, div1, div2, div3, idx_P2, idx_Q2, q1, q2
     70-113 ~ mask_shr, rst, lutStore, lutStore2, warp_count, mt, pr, qs, dep_thd_mask, dep_thd_bits, dep_thd_cnt, t, r, s, rs, x, y, z, ballot, warp_slices, partial, endCRST, str_d, str_h, str_w, rst_prime, x_prime, y_prime, z_prime

    // Two sets (j0, j1) of operand registers: 8 I values and 8 F values each.
    // They overlay the setup-phase names above.
      64-79 : j0Ix<0-7>, j0Fy<0-7>
      80-95 : j1Ix<0-7>, j1Fy<0-7>

    // Register pairs holding the 64-bit addresses used by the global loads.
      96-99 : trackI<0-1>, trackF<0-1>

    // Staging for the global loads and for their converted fp32 values.
    // Store elements 0-3 alias the raw load registers, and 104-107 is shared
    // by the upper half of store0I, the upper half of store1I and all of
    // storeF, so only one of those three groups can be live at a time.
    100-103 : load0I<0-3>
    100-103 : store0I<0-3>
    104-107 : store0I<4-7>

    108-111 : load1I<0-3>
    108-111 : store1I<0-3>
    104-107 : store1I<4-7>

    112-113 : loadF<0-1>
    104-107 : storeF<0-3>

    // One lookup-table entry; sliceIF names the same pair for a 64-bit load.
    114-115 : sliceI, sliceF
    114-115 : sliceIF<0-1>

    // Loop bookkeeping (116-140) and thread-derived values (141-155). The
    // second group is outside every overlay in this map, and the code after
    // the main loop still reads it.
    116-140 ~ writeFs, writeIs, offsetIn, offsetFk, posCRST, posCRSTf, channel, lutSize, lutSizeRcp, lutOffset, offsetI, offsetF, offsetIc, offsetFc
    141-155 ~ readFs, readIs, swapBuf, tid, idx_N, tid7, tid1_7, tid32, tid32_1

    // Names used by the code that runs after the main loop; they overlay
    // registers used by the earlier phases.
    72-91   : cs<0-7>, c<0-3>, b<0-7>
    72-83   ~ x<0-7>
    92-99   : out<0-7>
   100-101  : Out<0-1>
   102-103  : Sum<0-1>
   104-140  ~ writeCs, readCs, alpha, k, n, sum<0-3>, offset, out_offset, bsum_offset, tidOX, tidOY, tidOX2, preds, one

</REGISTER_MAPPING>

// Read the thread index and the three block indices. Each read signals its
// own barrier (1 to 4, third control field); a consumer waits for a value by
// setting the matching bit in the wait mask (first control field).
--:-:1:-:1      S2R tid,      SR_TID.X;
--:-:2:-:1      S2R idx_MPQk, SR_CTAID.X;
--:-:3:-:1      S2R idx_K,    SR_CTAID.Y;
--:-:4:-:1      S2R idx_N,    SR_CTAID.Z;

<SCHEDULE_BLOCK>
// Wait mask 01: barrier 1 (tid) must have completed. P0 = (tid >= 32).
01:-:-:-:1      ISETP.GE.AND P0, PT, tid, 32, PT;

[+ load_zeros() +]

[+ get_mpqk() +]

// tid7  = tid & 7
// tidIX = tid7 << 3   (x offset into the I tile, 8 elements per thread)
// tidFX = tid7 << 2   (x offset into the F tile, 4 elements per thread)

// tidY = tid >> 3
--:-:-:-:1      LOP.AND tid7,  tid,  7;
--:-:-:-:1      SHL     tidIX, tid7, 3;
--:-:-:-:1      SHL     tidFX, tid7, 2;
--:-:-:-:1      SHR.U32 tidY,  tid,  3;

// offsetFk = idx_K*32 + tidFX   (this block's part of the F offset)
--:-:-:-:1      ISCADD  offsetFk, idx_K, tidFX, 5;

// offsetIn = idx_N*128 + tidIX  (this block's part of the I offset)
// Wait mask 08: barrier 4 (idx_N) must have completed.
08:-:-:-:1      ISCADD  offsetIn, idx_N, tidIX, 7;

// writeFs = (32*tidY + tidFX) * 4
--:-:-:-:1      ISCADD  writeFs, tidY, tidFX, 5;
--:-:-:-:1      SHL     writeFs, writeFs, 2;

// Remap the IX dim to avoid bank conflicts when storing to shared:
// the x offset used here is tidFX, not tidIX.

// writeIs = (128*tidY + tidFX) * 4 + 4*szShareF  (the I tile follows the F tile)
--:-:-:-:1      ISCADD  writeIs, tidY, tidFX, 7;
--:-:-:-:1      ISCADD  writeIs, writeIs, 4x<szShareF>, 2;

// swapBuf = byte size of one F tile plus one I tile. The loop adds it to or
// subtracts it from the shared pointers and then negates it.
--:-:-:-:1      MOV32I swapBuf, 4x<szShareF + szShareI>;

// readFs  = (((tid & 16) >> 3) | (tid & 1)) << 4;
--:-:-:-:1      LOP.AND tid1,   tid,    1;
--:-:-:-:1      LOP.AND readFs, tid,    16;
--:-:-:-:1      SHR.U32 readFs, readFs, 3;
--:-:-:-:1      LOP.OR  readFs, readFs, tid1;
--:-:-:-:1      SHL     readFs, readFs, 4;

// readIs = ((((tid & 32) >> 1) | ((tid >> 1) & 7)) << 4) + 4*szShareF
// tid32, tid32_1 and tid1_7 are kept: the code after the main loop reuses them.
--:-:-:-:1      LOP.AND tid32,   tid,   32;
--:-:-:-:1      SHR.U32 tid32_1, tid32, 1;
--:-:-:-:1      BFE.U32 tid1_7,  tid,   0x301; // 3 bits at position 1
--:-:-:-:1      LOP.OR  readIs, tid1_7, tid32_1;
--:-:-:-:1      ISCADD  readIs, readIs, 4x<szShareF>, 4;
</SCHEDULE_BLOCK>

[+ load_lut() +]

// First tile. With P1 set the data comes from global memory; otherwise the
// same registers are filled from shared memory at addr_zero.
// NOTE(review): no instruction visible above writes P1; it is presumably
// produced by load_lut() - confirm.
// F: one 64-bit load (4 sixteen-bit elements), signalling barrier 2 or 5.
--:-:2:-:1  @P1 LDG.E.CI.64 loadF, [trackF];
--:-:5:-:1 @!P1 LDS.U.64    loadF, [addr_zero];

// I: two 128-bit loads (8 sixteen-bit elements each), 64 elements apart.
// The global loads signal barriers 3 and 4; only the second shared load
// signals a barrier (6).
--:-:3:-:1  @P1 LDG.E.128 load0I, [trackI + 2x<00>];
--:-:4:-:1  @P1 LDG.E.128 load1I, [trackI + 2x<64>];
--:-:-:-:1 @!P1 LDS.U.128 load0I, [addr_zero];
--:-:6:-:1 @!P1 LDS.U.128 load1I, [addr_zero];

// Widen F to fp32. Wait mask 12 (hex) = barriers 2 and 5, i.e. whichever of
// the two F loads ran. The last convert signals barrier 2 for the store.
12:-:-:-:1      [+ convert() +] storeF3, loadF1.H1;
--:-:-:-:1      [+ convert() +] storeF2, loadF1.H0;
--:-:-:-:1      [+ convert() +] storeF1, loadF0.H1;
--:-:2:-:2      [+ convert() +] storeF0, loadF0.H0;

// Store F; read barrier 1 is signalled once the store has read its operands.
02:1:-:-:2      STS.128 [writeFs], storeF0;

// Widen the first I vector. Wait mask 25 (hex) = barriers 1, 3 and 6: the F
// store must have read storeF (same registers as store0I4..7) and the I load
// must have landed. Elements are converted from highest to lowest because
// store0I0..3 alias load0I0..3: each source is read before it is overwritten.
25:-:-:-:1      [+ convert() +] store0I7, load0I3.H1;
--:-:-:-:1      [+ convert() +] store0I6, load0I3.H0;
--:-:-:-:1      [+ convert() +] store0I5, load0I2.H1;
--:-:2:-:1      [+ convert() +] store0I4, load0I2.H0;
--:-:-:-:1      [+ convert() +] store0I3, load0I1.H1;
--:-:-:-:1      [+ convert() +] store0I2, load0I1.H0;
--:-:-:-:1      [+ convert() +] store0I1, load0I0.H1;
--:-:3:-:1      [+ convert() +] store0I0, load0I0.H0;

// Upper four values go 32 elements further along the row than the lower four.
02:-:-:-:1      STS.128 [writeIs + 4x<32>], store0I4;
04:1:-:-:2      STS.128 [writeIs + 4x<00>], store0I0;

// Second I vector. Wait mask 09 = barriers 1 and 4: the previous stores must
// have read registers 104-107 (shared with store1I4..7) and load1I must have
// landed. Same highest-to-lowest order as above.
09:-:-:-:1      [+ convert() +] store1I7, load1I3.H1;
--:-:-:-:1      [+ convert() +] store1I6, load1I3.H0;
--:-:-:-:1      [+ convert() +] store1I5, load1I2.H1;
--:-:2:-:1      [+ convert() +] store1I4, load1I2.H0;
--:-:-:-:1      [+ convert() +] store1I3, load1I1.H1;
--:-:-:-:1      [+ convert() +] store1I2, load1I1.H0;
--:-:-:-:1      [+ convert() +] store1I1, load1I0.H1;
--:-:3:-:1      [+ convert() +] store1I0, load1I0.H0;

02:-:-:-:1      STS.128 [writeIs + 4x<96>], store1I4;
04:1:-:-:1      STS.128 [writeIs + 4x<64>], store1I0;

[+ loop_setup() +]

// Issue the global loads for the next tile before entering the loop. They
// signal barriers 2, 3 and 4, which the converts inserted into the loop wait
// on. The last load also signals read barrier 5 once trackI has been read;
// the in-loop update of trackI waits on that barrier.
--:-:2:-:2  @P1 LDG.E.CI.64 loadF,  [trackF];
--:-:3:-:1  @P1 LDG.E.128   load0I, [trackI + 2x<00>];
--:5:4:-:1  @P1 LDG.E.128   load1I, [trackI + 2x<64>];

[-
    # Instructions to be interleaved with the main loop body. Keys have the
    # form jNcM; presumably main_loop() from the include emits each string at
    # position M of unrolled step N (confirm in xconv_xprop_common.sass).
    # The strings are emitted verbatim; the backslash only stops Perl from
    # interpolating the @ of the predicate prefix.
    # Predicates: P1 guards the address computation and global loads for the
    # following tile, P0 guards the converts, shared stores, buffer swap and
    # the branch back to LOOP.
    our $convert;
    our %insert =
    (
        # P1 = (posCRST >= 0), P0 = (posCRST >= -8).
        j0c1  => "--:-:-:-:1      ISETP.GE.AND P1, PT, posCRST,  RZ, PT;\n",
        j0c3  => "--:-:-:-:1      ISETP.GE.AND P0, PT, posCRST, -8, PT;\n",

        # channel = trunc(posCRST * lutSizeRcp * (1 + 2^-24)), computed in
        # fp32 (the FFMA adds channel * 2^-24 to channel before truncation).
        # Barrier 6 sequences the int-to-float and float-to-int conversions.
        j0c13 => "--:-:6:-:1  \@P1 I2F.F32.S32 posCRSTf, posCRST;\n",

        j0c39 => "20:-:-:-:1  \@P1 FMUL channel, posCRSTf, lutSizeRcp;\n",
        j0c44 => "--:-:-:-:1  \@P1 FFMA channel, channel, 5.9604644775390625e-08, channel;\n",
        j0c46 => "--:-:6:-:1  \@P1 F2I.S32.F32.TRUNC channel, channel;\n",

        # lutOffset = (posCRST - channel * lutSize) * 8, the byte offset of an
        # 8-byte table entry. The multiply uses 16-bit operands.
        j1c8  => "20:-:-:-:1  \@P1 VMAD.U16.U16 lutOffset, -channel, lutSize, posCRST;\n",
        j1c13 => "--:-:-:-:1  \@P1 SHL lutOffset, lutOffset, 3;\n",

        # Convert the F values loaded earlier (wait on barrier 2) and store
        # them to shared memory.
        j1c33 => "02:-:-:-:1  \@P0 $convert storeF3, loadF1.H1;\n",
        j1c37 => "--:-:-:-:1  \@P0 $convert storeF2, loadF1.H0;\n",
        j1c41 => "--:-:-:-:1  \@P0 $convert storeF1, loadF0.H1;\n",
        j1c45 => "--:-:2:-:1  \@P0 $convert storeF0, loadF0.H0;\n",

        j1c60 => "02:-:-:-:1  \@P0 STS.128 [writeFs], storeF0;\n",

        # Fetch the table entry (sliceI, sliceF) as one 64-bit load.
        j1c62 => "--:-:2:-:1  \@P1 LDS.U.64 sliceIF, [lutOffset + addr_lut];\n",

        # offsetFc = channel * param_KRST (single 16-bit multiply).
        # offsetIc = channel * param_DHWN, built from the low and high halves
        # of param_DHWN with two multiplies.
        # posCRST is then stepped down by 8 (unpredicated).
        j2c10 => "--:-:-:-:1  \@P1 XMAD     offsetFc, channel, param_KRST, RZ;\n",
        j2c15 => "--:-:-:-:1  \@P1 XMAD     offsetIc, channel, param_DHWN,    RZ;\n",
        j2c20 => "--:-:-:-:1  \@P1 XMAD.PSL offsetIc, channel, param_DHWN.H1, offsetIc;\n",
        j2c22 => "--:-:-:-:1      IADD posCRST, posCRST, -8;\n",

        # offsetF = offsetFk + offsetFc + sliceF (waits for the table load),
        # trackF  = param_F + offsetF * 2 as a 64-bit address.
        # offsetI = offsetIn + offsetIc + sliceI.
        j2c29 => "02:-:-:-:1  \@P1 IADD3 offsetF, offsetFk, offsetFc, sliceF;\n",
        j2c34 => "--:-:-:-:1  \@P1 LEA      trackF0.CC, offsetF, param_F[0],     1;\n",
        j2c36 => "--:-:-:-:1  \@P1 IADD3 offsetI, offsetIn, offsetIc, sliceI;\n",
        j2c38 => "--:-:-:-:1  \@P1 LEA.HI.X trackF1,    offsetF, param_F[1], RZ, 1;\n",

        # Next F load, signalling barrier 2 for the converts of the next pass.
        j2c40 => "--:-:2:-:1  \@P1 LDG.E.CI.64 loadF0, [trackF];\n",


        # First I vector: wait for its load (barrier 3), convert from the
        # highest element down (the low four outputs alias the load
        # registers), then store both halves.
        j3c29 => "04:-:-:-:1  \@P0 $convert store0I7, load0I3.H1;\n",
        j3c33 => "--:-:-:-:1  \@P0 $convert store0I6, load0I3.H0;\n",
        j3c37 => "--:-:-:-:1  \@P0 $convert store0I5, load0I2.H1;\n",
        j3c41 => "--:-:6:-:1  \@P0 $convert store0I4, load0I2.H0;\n",
        j3c45 => "--:-:-:-:1  \@P0 $convert store0I3, load0I1.H1;\n",
        j3c49 => "--:-:-:-:1  \@P0 $convert store0I2, load0I1.H0;\n",
        j3c53 => "--:-:-:-:1  \@P0 $convert store0I1, load0I0.H1;\n",
        j3c57 => "--:-:3:-:1  \@P0 $convert store0I0, load0I0.H0;\n",

        j3c59 => "20:-:-:-:1  \@P0 STS.128 [writeIs + 4x<32>], store0I4;\n",
        j4c8  => "04:3:-:-:1  \@P0 STS.128 [writeIs + 4x<00>], store0I0;\n",

        # trackI = param_I + offsetI * 2 as a 64-bit address. Wait mask 10
        # (hex) = barrier 5: the previous load must have read trackI first.
        j4c50 => "10:-:-:-:1  \@P1 LEA      trackI0.CC, offsetI, param_I[0],     1;\n",
        j4c55 => "--:-:-:-:1  \@P1 LEA.HI.X trackI1,    offsetI, param_I[1], RZ, 1;\n",

        # Reload the first I vector once the store above has read its
        # registers (read barrier 3).
        j4c61 => "04:-:3:-:1  \@P1 LDG.E.128 load0I0, [trackI + 2x<00>];\n",


        # Second I vector: same sequence, using barrier 4.
        j5c29 => "08:-:-:-:1  \@P0 $convert store1I7, load1I3.H1;\n",
        j5c33 => "--:-:-:-:1  \@P0 $convert store1I6, load1I3.H0;\n",
        j5c37 => "--:-:-:-:1  \@P0 $convert store1I5, load1I2.H1;\n",
        j5c41 => "--:-:6:-:1  \@P0 $convert store1I4, load1I2.H0;\n",
        j5c45 => "--:-:-:-:1  \@P0 $convert store1I3, load1I1.H1;\n",
        j5c49 => "--:-:-:-:1  \@P0 $convert store1I2, load1I1.H0;\n",
        j5c53 => "--:-:-:-:1  \@P0 $convert store1I1, load1I0.H1;\n",
        j5c57 => "--:-:4:-:1  \@P0 $convert store1I0, load1I0.H0;\n",

        j5c59 => "20:-:-:-:1  \@P0 STS.128 [writeIs + 4x<96>], store1I4;\n",
        j6c8  => "08:4:-:-:1  \@P0 STS.128 [writeIs + 4x<64>], store1I0;\n",

        # Reload the second I vector; read barrier 5 protects trackI.
        j6c61 => "08:5:4:-:1  \@P1 LDG.E.128 load1I0, [trackI + 2x<64>];\n",

        # Synchronise the block, then exchange buffers: read pointers move by
        # -swapBuf, write pointers by +swapBuf, and swapBuf is negated so the
        # next pass moves them back.
        j6c63   => "--:-:-:-:5  \@P0 BAR.SYNC 0;\n" .
                   "--:-:-:-:1  \@P0 IADD readIs,  readIs, -swapBuf;\n" .
                   "--:-:-:-:1  \@P0 IADD readFs,  readFs, -swapBuf;\n" .
                   "--:-:-:-:1  \@P0 IADD writeIs, writeIs, swapBuf;\n" .
                   "--:-:-:-:1  \@P0 IADD writeFs, writeFs, swapBuf;\n" .
                   "--:-:-:-:1  \@P0 IADD swapBuf, RZ,     -swapBuf;\n",

        # Loop while P0 holds.
        j7c63 => "--:-:-:Y:5  \@P0 BRA.U LOOP;\n",
    );
-]

// Target of the predicated branch held in the j7c63 entry of the insert
// table. The loop body itself is generated by main_loop() from the include.
LOOP:

[+ main_loop() +]

// Reload m, p, q, k from shared memory with one 128-bit load (barrier 1).
// The mpqk registers are the same physical registers as idx_M, idx_P, idx_Q
// and idx_K, and they also fall inside the j0Ix operand range.
--:-:1:-:1      LDS.U.128 mpqk, [addr_mpqk];

<SCHEDULE_BLOCK>

// tidOX = ((tid & 7) << 3) + ((tid & 32) << 1)
// tidOY = (tid & 31) >> 3
// tid32 is shifted in place; tid32_1 still holds (tid & 32) >> 1.
--:-:-:-:1      SHL     tid32,  tid32, 1;
--:-:-:-:1      ISCADD  tidOX,  tid7,  tid32, 3;
--:-:-:-:1      LOP.AND tidOY,  tid,   31;
--:-:-:-:1      SHR.U32 tidOY,  tidOY, 3;

// If swapBuf is positive, move readFs back by swapBuf (presumably this
// returns readFs to the first buffer, where it equals the plain thread
// offset computed during setup - confirm against loop_setup()).
--:-:-:-:1      ISETP.GT.AND P2, PT, swapBuf, RZ, PT;
--:-:-:-:1  @P2 IADD readFs,  readFs, -swapBuf;

// readIs = (((tid & 32) >> 1) + (((tid >> 1) & 7) << 1)) << 4
// Recomputed from tid1_7 and tid32_1; the loop value of readIs is discarded.
--:-:-:-:1      ISCADD readIs, tid1_7, tid32_1, 1;
--:-:-:-:1      SHL    readIs, readIs, 4;

// Div by 4 here collapses k stride
// writeCs = (readFs / 4) * 128 + readIs;
--:-:-:-:1      ISCADD  writeCs, readFs, readIs, 5;

// readCs  = 4 * (tidOX + (tidOY * 128))
--:-:-:-:1      ISCADD readCs, tidOY, tidOX, 7;
--:-:-:-:1      SHL    readCs, readCs, 2;

// n = blkI*128 + tidOX;
--:-:-:-:1      ISCADD n, idx_N, tidOX, 7;

// Mul by 4 here expands k stride back out
// k = blkF*32 + tidOY * 4
// NOTE(review): idx_K is the fourth register filled by the mpqk load above
// (barrier 1), and no instruction visible in this block waits on barrier 1.
// Confirm that output_setup() or the scheduler guarantees the load has
// completed before idx_K is read here.
--:-:-:-:1      SHL    tidOY,   tidOY, 2;
--:-:-:-:1      ISCADD k, idx_K, tidOY, 5;


[+ output_setup(63, 1, 6) +]

</SCHEDULE_BLOCK>

[+ output() +]
